{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp common._base_multivariate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BaseMultivariate\n",
    "\n",
 The `BaseW">
    "> The `BaseMultivariate` class contains standard methods shared across window-based multivariate neural networks; in contrast to recurrent neural networks, these models commit to a fixed sequence length input."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The standard methods include data preprocessing `_normalization`, optimization utilities like parameter initialization, `training_step`, `validation_step`, and shared `fit` and `predict` methods.These shared methods enable all the `neuralforecast.models` compatibility with the `core.NeuralForecast` wrapper class. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from fastcore.test import test_eq\n",
    "from nbdev.showdoc import show_doc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import pytorch_lightning as pl\n",
    "import neuralforecast.losses.pytorch as losses\n",
    "\n",
    "from neuralforecast.common._base_model import BaseModel\n",
    "from neuralforecast.common._scalers import TemporalNorm\n",
    "from neuralforecast.tsdataset import TimeSeriesDataModule\n",
    "from neuralforecast.utils import get_indexer_raise_missing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class BaseMultivariate(BaseModel):\n",
    "    \"\"\" Base Multivariate\n",
    "    \n",
    "    Base class for all multivariate models. The forecasts for all time-series are produced simultaneously \n",
    "    within each window, which are randomly sampled during training.\n",
    "    \n",
    "    This class implements the basic functionality for all windows-based models, including:\n",
    "    - PyTorch Lightning's methods training_step, validation_step, predict_step.<br>\n",
    "    - fit and predict methods used by NeuralForecast.core class.<br>\n",
    "    - sampling and wrangling methods to generate multivariate windows.\n",
    "    \"\"\"\n",
    "    def __init__(self, \n",
    "                 h,\n",
    "                 input_size,\n",
    "                 loss,\n",
    "                 valid_loss,\n",
    "                 learning_rate,\n",
    "                 max_steps,\n",
    "                 val_check_steps,\n",
    "                 n_series,\n",
    "                 batch_size,\n",
    "                 step_size=1,\n",
    "                 num_lr_decays=0,\n",
    "                 early_stop_patience_steps=-1,\n",
    "                 scaler_type='robust',\n",
    "                 futr_exog_list=None,\n",
    "                 hist_exog_list=None,\n",
    "                 stat_exog_list=None,\n",
    "                 num_workers_loader=0,\n",
    "                 drop_last_loader=False,\n",
    "                 random_seed=1, \n",
    "                 alias=None,\n",
    "                 optimizer=None,\n",
    "                 optimizer_kwargs=None,\n",
    "                 lr_scheduler=None,\n",
    "                 lr_scheduler_kwargs=None,\n",
    "                 **trainer_kwargs):\n",
    "        super().__init__(\n",
    "            random_seed=random_seed,\n",
    "            loss=loss,\n",
    "            valid_loss=valid_loss,\n",
    "            optimizer=optimizer,\n",
    "            optimizer_kwargs=optimizer_kwargs,\n",
    "            lr_scheduler=lr_scheduler,\n",
    "            lr_scheduler_kwargs=lr_scheduler_kwargs,            \n",
    "            futr_exog_list=futr_exog_list,\n",
    "            hist_exog_list=hist_exog_list,\n",
    "            stat_exog_list=stat_exog_list,\n",
    "            max_steps=max_steps,\n",
    "            early_stop_patience_steps=early_stop_patience_steps,\n",
    "            **trainer_kwargs,\n",
    "        )\n",
    "\n",
    "        # Padder to complete train windows, \n",
    "        # example y=[1,2,3,4,5] h=3 -> last y_output = [5,0,0]\n",
    "        self.h = h\n",
    "        self.input_size = input_size\n",
    "        self.n_series = n_series\n",
    "        self.padder = nn.ConstantPad1d(padding=(0, self.h), value=0)\n",
    "\n",
    "        # Multivariate models do not support these loss functions yet.\n",
    "        unsupported_losses = (\n",
    "            losses.sCRPS,\n",
    "            losses.MQLoss,\n",
    "            losses.DistributionLoss,\n",
    "            losses.PMM,\n",
    "            losses.GMM,\n",
    "            losses.HuberMQLoss,\n",
    "            losses.MASE,\n",
    "            losses.relMSE,\n",
    "            losses.NBMM,\n",
    "        )\n",
    "        if isinstance(self.loss, unsupported_losses):\n",
    "            raise Exception(f\"{self.loss} is not supported in a Multivariate model.\")            \n",
    "        if isinstance(self.valid_loss, unsupported_losses):\n",
    "            raise Exception(f\"{self.valid_loss} is not supported in a Multivariate model.\")            \n",
    "\n",
    "        self.batch_size = batch_size\n",
    "        \n",
    "        # Optimization\n",
    "        self.learning_rate = learning_rate\n",
    "        self.max_steps = max_steps\n",
    "        self.num_lr_decays = num_lr_decays\n",
    "        self.lr_decay_steps = max(max_steps // self.num_lr_decays, 1) if self.num_lr_decays > 0 else 10e7\n",
    "        self.early_stop_patience_steps = early_stop_patience_steps\n",
    "        self.val_check_steps = val_check_steps\n",
    "        self.step_size = step_size\n",
    "\n",
    "        # Scaler\n",
    "        self.scaler = TemporalNorm(scaler_type=scaler_type, dim=2) # Time dimension is in the second axis\n",
    "\n",
    "        # Fit arguments\n",
    "        self.val_size = 0\n",
    "        self.test_size = 0\n",
    "\n",
    "        # Model state\n",
    "        self.decompose_forecast = False\n",
    "\n",
    "        # DataModule arguments\n",
    "        self.num_workers_loader = num_workers_loader\n",
    "        self.drop_last_loader = drop_last_loader\n",
    "        # used by on_validation_epoch_end hook\n",
    "        self.validation_step_outputs = []\n",
    "        self.alias = alias\n",
    "\n",
    "    def _create_windows(self, batch, step):\n",
    "        # Parse common data\n",
    "        window_size = self.input_size + self.h\n",
    "        temporal_cols = batch['temporal_cols']\n",
    "        temporal = batch['temporal']\n",
    "\n",
    "        if step == 'train':\n",
    "            if self.val_size + self.test_size > 0:\n",
    "                cutoff = -self.val_size - self.test_size\n",
    "                temporal = temporal[:, :, :cutoff]\n",
    "\n",
    "            temporal = self.padder(temporal)\n",
    "            windows = temporal.unfold(dimension=-1, \n",
    "                                      size=window_size, \n",
    "                                      step=self.step_size)\n",
    "            # [n_series, C, Ws, L+H] 0, 1, 2, 3\n",
    "\n",
    "            # Sample and Available conditions\n",
    "            available_idx = temporal_cols.get_loc('available_mask')\n",
    "            sample_condition = windows[:, available_idx, :, -self.h:]\n",
    "            sample_condition = torch.sum(sample_condition, axis=2) # Sum over time\n",
    "            sample_condition = torch.sum(sample_condition, axis=0) # Sum over time-series\n",
    "            available_condition = windows[:, available_idx, :, :-self.h]\n",
    "            available_condition = torch.sum(available_condition, axis=2) # Sum over time\n",
    "            available_condition = torch.sum(available_condition, axis=0) # Sum over time-series\n",
    "            final_condition = (sample_condition > 0) & (available_condition > 0) # Of shape [Ws]\n",
    "            windows = windows[:, :, final_condition, :]\n",
    "\n",
    "            # Get Static data\n",
    "            static = batch.get('static', None)\n",
    "            static_cols = batch.get('static_cols', None)\n",
    "\n",
    "            # Protection of empty windows\n",
    "            if final_condition.sum() == 0:\n",
    "                raise Exception('No windows available for training')\n",
    "\n",
    "            # Sample windows\n",
    "            n_windows = windows.shape[2]\n",
    "            if self.batch_size is not None:\n",
    "                w_idxs = np.random.choice(n_windows, \n",
    "                                          size=self.batch_size,\n",
    "                                          replace=(n_windows < self.batch_size))\n",
    "                windows = windows[:, :, w_idxs, :]\n",
    "\n",
    "            windows = windows.permute(2, 1, 3, 0) # [Ws, C, L+H, n_series]\n",
    "\n",
    "            windows_batch = dict(temporal=windows,\n",
    "                                 temporal_cols=temporal_cols,\n",
    "                                 static=static,\n",
    "                                 static_cols=static_cols)\n",
    "\n",
    "            return windows_batch\n",
    "\n",
    "        elif step in ['predict', 'val']:\n",
    "\n",
    "            if step == 'predict':\n",
    "                predict_step_size = self.predict_step_size\n",
    "                cutoff = - self.input_size - self.test_size\n",
    "                temporal = batch['temporal'][:, :, cutoff:]\n",
    "\n",
    "            elif step == 'val':\n",
    "                predict_step_size = self.step_size\n",
    "                cutoff = -self.input_size - self.val_size - self.test_size\n",
    "                if self.test_size > 0:\n",
    "                    temporal = batch['temporal'][:, :, cutoff:-self.test_size]\n",
    "                else:\n",
    "                    temporal = batch['temporal'][:, :, cutoff:]\n",
    "\n",
    "            if (step=='predict') and (self.test_size==0) and (len(self.futr_exog_list)==0):\n",
    "                temporal = self.padder(temporal)\n",
    "\n",
    "            windows = temporal.unfold(dimension=-1,\n",
    "                                      size=window_size,\n",
    "                                      step=predict_step_size)\n",
    "            # [n_series, C, Ws, L+H] -> [Ws, C, L+H, n_series]\n",
    "            windows = windows.permute(2, 1, 3, 0)\n",
    "\n",
    "            # Get Static data\n",
    "            static = batch.get('static', None)\n",
    "            static_cols=batch.get('static_cols', None)\n",
    "\n",
    "            windows_batch = dict(temporal=windows,\n",
    "                                 temporal_cols=temporal_cols,\n",
    "                                 static=static,\n",
    "                                 static_cols=static_cols)\n",
    "\n",
    "\n",
    "            return windows_batch\n",
    "        else:\n",
    "            raise ValueError(f'Unknown step {step}') \n",
    "\n",
    "    def _normalization(self, windows, y_idx):\n",
    "        \n",
    "        # windows are already filtered by train/validation/test\n",
    "        # from the `create_windows_method` nor leakage risk\n",
    "        temporal = windows['temporal']                  # [Ws, C, L+H, n_series]\n",
    "        temporal_cols = windows['temporal_cols'].copy() # [Ws, C, L+H, n_series]\n",
    "\n",
    "        # To avoid leakage uses only the lags\n",
    "        temporal_data_cols = self._get_temporal_exogenous_cols(temporal_cols=temporal_cols)\n",
    "        temporal_idxs = get_indexer_raise_missing(temporal_cols, temporal_data_cols)\n",
    "        temporal_idxs = np.append(y_idx, temporal_idxs)\n",
    "        temporal_data = temporal[:, temporal_idxs, :, :]\n",
    "        temporal_mask = temporal[:, temporal_cols.get_loc('available_mask'), :, :].clone()\n",
    "        temporal_mask[:, -self.h:, :] = 0.0\n",
    "\n",
    "        # Normalize. self.scaler stores the shift and scale for inverse transform\n",
    "        temporal_mask = temporal_mask.unsqueeze(1) # Add channel dimension for scaler.transform.\n",
    "        temporal_data = self.scaler.transform(x=temporal_data, mask=temporal_mask)\n",
    "        # Replace values in windows dict\n",
    "        temporal[:, temporal_idxs, :, :] = temporal_data\n",
    "        windows['temporal'] = temporal\n",
    "\n",
    "        return windows\n",
    "\n",
    "    def _inv_normalization(self, y_hat, temporal_cols, y_idx):\n",
    "        # Receives window predictions [Ws, H, n_series]\n",
    "        # Broadcasts outputs and inverts normalization\n",
    "\n",
    "        # Add C dimension\n",
    "        # if y_hat.ndim == 2:\n",
    "        #     remove_dimension = True\n",
    "        #     y_hat = y_hat.unsqueeze(-1)\n",
    "        # else:\n",
    "        #     remove_dimension = False\n",
    "        \n",
    "        y_scale = self.scaler.x_scale[:, [y_idx], :].squeeze(1)\n",
    "        y_loc = self.scaler.x_shift[:, [y_idx], :].squeeze(1)\n",
    "\n",
    "        # y_scale = torch.repeat_interleave(y_scale, repeats=y_hat.shape[-1], dim=-1)\n",
    "        # y_loc = torch.repeat_interleave(y_loc, repeats=y_hat.shape[-1], dim=-1)\n",
    "\n",
    "        y_hat = self.scaler.inverse_transform(z=y_hat, x_scale=y_scale, x_shift=y_loc)\n",
    "\n",
    "        # if remove_dimension:\n",
    "        #     y_hat = y_hat.squeeze(-1)\n",
    "        #     y_loc = y_loc.squeeze(-1)\n",
    "        #     y_scale = y_scale.squeeze(-1)\n",
    "\n",
    "        return y_hat, y_loc, y_scale\n",
    "\n",
    "    def _parse_windows(self, batch, windows):\n",
    "        # Temporal: [Ws, C, L+H, n_series]\n",
    "\n",
    "        # Filter insample lags from outsample horizon\n",
    "        mask_idx = batch['temporal_cols'].get_loc('available_mask')\n",
    "        y_idx = batch['y_idx']        \n",
    "        insample_y = windows['temporal'][:, y_idx, :-self.h, :]\n",
    "        insample_mask = windows['temporal'][:, mask_idx, :-self.h, :]\n",
    "        outsample_y = windows['temporal'][:, y_idx, -self.h:, :]\n",
    "        outsample_mask = windows['temporal'][:, mask_idx, -self.h:, :]\n",
    "\n",
    "        # Filter historic exogenous variables\n",
    "        if len(self.hist_exog_list):\n",
    "            hist_exog_idx = get_indexer_raise_missing(windows['temporal_cols'], self.hist_exog_list)\n",
    "            hist_exog = windows['temporal'][:, hist_exog_idx, :-self.h, :]\n",
    "        else:\n",
    "            hist_exog = None\n",
    "        \n",
    "        # Filter future exogenous variables\n",
    "        if len(self.futr_exog_list):\n",
    "            futr_exog_idx = get_indexer_raise_missing(windows['temporal_cols'], self.futr_exog_list)\n",
    "            futr_exog = windows['temporal'][:, futr_exog_idx, :, :]\n",
    "        else:\n",
    "            futr_exog = None\n",
    "\n",
    "        # Filter static variables\n",
    "        if len(self.stat_exog_list):\n",
    "            static_idx = get_indexer_raise_missing(windows['static_cols'], self.stat_exog_list)\n",
    "            stat_exog = windows['static'][:, static_idx]\n",
    "        else:\n",
    "            stat_exog = None\n",
    "\n",
    "        return insample_y, insample_mask, outsample_y, outsample_mask, \\\n",
    "               hist_exog, futr_exog, stat_exog\n",
    "\n",
    "    def training_step(self, batch, batch_idx):        \n",
    "        # Create and normalize windows [batch_size, n_series, C, L+H]\n",
    "        windows = self._create_windows(batch, step='train')\n",
    "        y_idx = batch['y_idx']\n",
    "        windows = self._normalization(windows=windows, y_idx=y_idx)\n",
    "\n",
    "        # Parse windows\n",
    "        insample_y, insample_mask, outsample_y, outsample_mask, \\\n",
    "               hist_exog, futr_exog, stat_exog = self._parse_windows(batch, windows)\n",
    "\n",
    "        windows_batch = dict(insample_y=insample_y, # [Ws, L, n_series]\n",
    "                             insample_mask=insample_mask, # [Ws, L, n_series]\n",
    "                             futr_exog=futr_exog, # [Ws, F, L + h, n_series]\n",
    "                             hist_exog=hist_exog, # [Ws, X, L, n_series]\n",
    "                             stat_exog=stat_exog) # [n_series, S]\n",
    "\n",
    "        # Model Predictions\n",
    "        output = self(windows_batch)\n",
    "        if self.loss.is_distribution_output:\n",
    "            outsample_y, y_loc, y_scale = self._inv_normalization(y_hat=outsample_y,\n",
    "                                            temporal_cols=batch['temporal_cols'],\n",
    "                                            y_idx=y_idx)\n",
    "            distr_args = self.loss.scale_decouple(output=output, loc=y_loc, scale=y_scale)\n",
    "            loss = self.loss(y=outsample_y, distr_args=distr_args, mask=outsample_mask)\n",
    "        else:\n",
    "            loss = self.loss(y=outsample_y, y_hat=output, mask=outsample_mask)\n",
    "\n",
    "        if torch.isnan(loss):\n",
    "            print('Model Parameters', self.hparams)\n",
    "            print('insample_y', torch.isnan(insample_y).sum())\n",
    "            print('outsample_y', torch.isnan(outsample_y).sum())\n",
    "            print('output', torch.isnan(output).sum())\n",
    "            raise Exception('Loss is NaN, training stopped.')\n",
    "\n",
    "        self.log(\n",
    "            'train_loss',\n",
    "            loss.item(),\n",
    "            batch_size=outsample_y.size(0),\n",
    "            prog_bar=True,\n",
    "            on_epoch=True,\n",
    "        )\n",
    "        self.train_trajectories.append((self.global_step, loss.item()))\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        if self.val_size == 0:\n",
    "            return np.nan\n",
    "        \n",
    "        # Create and normalize windows [Ws, L+H, C]\n",
    "        windows = self._create_windows(batch, step='val')\n",
    "        y_idx = batch['y_idx']\n",
    "        windows = self._normalization(windows=windows, y_idx=y_idx)\n",
    "\n",
    "        # Parse windows\n",
    "        insample_y, insample_mask, outsample_y, outsample_mask, \\\n",
    "               hist_exog, futr_exog, stat_exog = self._parse_windows(batch, windows)\n",
    "\n",
    "        windows_batch = dict(insample_y=insample_y, # [Ws, L, n_series]\n",
    "                             insample_mask=insample_mask, # [Ws, L, n_series]\n",
    "                             futr_exog=futr_exog, # [Ws, F, L + h, n_series]\n",
    "                             hist_exog=hist_exog, # [Ws, X, L, n_series]\n",
    "                             stat_exog=stat_exog) # [n_series, S]\n",
    "\n",
    "        # Model Predictions\n",
    "        output = self(windows_batch)\n",
    "        if self.loss.is_distribution_output:\n",
    "            outsample_y, y_loc, y_scale = self._inv_normalization(y_hat=outsample_y,\n",
    "                                            temporal_cols=batch['temporal_cols'],\n",
    "                                            y_idx=y_idx)\n",
    "            distr_args = self.loss.scale_decouple(output=output, loc=y_loc, scale=y_scale)\n",
    "\n",
    "            if str(type(self.valid_loss)) in\\\n",
    "                [\"<class 'neuralforecast.losses.pytorch.sCRPS'>\", \"<class 'neuralforecast.losses.pytorch.MQLoss'>\"]:\n",
    "                _, output = self.loss.sample(distr_args=distr_args)\n",
    "\n",
    "        # Validation Loss evaluation\n",
    "        if self.valid_loss.is_distribution_output:\n",
    "            valid_loss = self.valid_loss(y=outsample_y, distr_args=distr_args, mask=outsample_mask)\n",
    "        else:\n",
    "            valid_loss = self.valid_loss(y=outsample_y, y_hat=output, mask=outsample_mask)\n",
    "\n",
    "        if torch.isnan(valid_loss):\n",
    "            raise Exception('Loss is NaN, training stopped.')\n",
    "\n",
    "        self.log(\n",
    "            'valid_loss',\n",
    "            valid_loss.item(),\n",
    "            batch_size=outsample_y.size(0),\n",
    "            prog_bar=True,\n",
    "            on_epoch=True,\n",
    "        )\n",
    "        self.validation_step_outputs.append(valid_loss)\n",
    "        return valid_loss\n",
    "\n",
    "    def predict_step(self, batch, batch_idx):        \n",
    "        # Create and normalize windows [Ws, L+H, C]\n",
    "        windows = self._create_windows(batch, step='predict')\n",
    "        y_idx = batch['y_idx']        \n",
    "        windows = self._normalization(windows=windows, y_idx=y_idx)\n",
    "\n",
    "        # Parse windows\n",
    "        insample_y, insample_mask, _, _, \\\n",
    "               hist_exog, futr_exog, stat_exog = self._parse_windows(batch, windows)\n",
    "\n",
    "        windows_batch = dict(insample_y=insample_y, # [Ws, L, n_series]\n",
    "                             insample_mask=insample_mask, # [Ws, L, n_series]\n",
    "                             futr_exog=futr_exog, # [Ws, F, L + h, n_series]\n",
    "                             hist_exog=hist_exog, # [Ws, X, L, n_series]\n",
    "                             stat_exog=stat_exog) # [n_series, S]\n",
    "\n",
    "        # Model Predictions\n",
    "        output = self(windows_batch)\n",
    "        if self.loss.is_distribution_output:\n",
    "            _, y_loc, y_scale = self._inv_normalization(y_hat=torch.empty(size=(insample_y.shape[0], \n",
    "                                                                                self.h, \n",
    "                                                                                self.n_series),\n",
    "                                                            dtype=output[0].dtype,\n",
    "                                                            device=output[0].device),\n",
    "                                            temporal_cols=batch['temporal_cols'],\n",
    "                                            y_idx=y_idx)\n",
    "            distr_args = self.loss.scale_decouple(output=output, loc=y_loc, scale=y_scale)\n",
    "            _, y_hat = self.loss.sample(distr_args=distr_args)\n",
    "\n",
    "            if self.loss.return_params:\n",
    "                distr_args = torch.stack(distr_args, dim=-1)\n",
    "                distr_args = torch.reshape(distr_args, (len(windows[\"temporal\"]), self.h, -1))\n",
    "                y_hat = torch.concat((y_hat, distr_args), axis=2)\n",
    "        else:\n",
    "            y_hat, _, _ = self._inv_normalization(y_hat=output,\n",
    "                                            temporal_cols=batch['temporal_cols'],\n",
    "                                            y_idx=y_idx)\n",
    "        return y_hat\n",
    "    \n",
    "    def fit(self, dataset, val_size=0, test_size=0, random_seed=None, distributed_config=None):\n",
    "        \"\"\" Fit.\n",
    "\n",
    "        The `fit` method, optimizes the neural network's weights using the\n",
    "        initialization parameters (`learning_rate`, `windows_batch_size`, ...)\n",
    "        and the `loss` function as defined during the initialization. \n",
    "        Within `fit` we use a PyTorch Lightning `Trainer` that\n",
    "        inherits the initialization's `self.trainer_kwargs`, to customize\n",
    "        its inputs, see [PL's trainer arguments](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).\n",
    "\n",
    "        The method is designed to be compatible with SKLearn-like classes\n",
    "        and in particular to be compatible with the StatsForecast library.\n",
    "\n",
    "        By default the `model` is not saving training checkpoints to protect \n",
    "        disk memory, to get them change `enable_checkpointing=True` in `__init__`.\n",
    "\n",
    "        **Parameters:**<br>\n",
    "        `dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).<br>\n",
    "        `val_size`: int, validation size for temporal cross-validation.<br>\n",
    "        `test_size`: int, test size for temporal cross-validation.<br>\n",
    "        \"\"\"\n",
    "        if distributed_config is not None:\n",
    "            raise ValueError(\"multivariate models cannot be trained using distributed data parallel.\")\n",
    "        return self._fit(\n",
    "            dataset=dataset,\n",
    "            batch_size=self.n_series,\n",
    "            valid_batch_size=self.n_series,\n",
    "            val_size=val_size,\n",
    "            test_size=test_size,\n",
    "            random_seed=random_seed,\n",
    "            shuffle_train=False,\n",
    "            distributed_config=None,\n",
    "        )\n",
    "\n",
    "    def predict(self, dataset, test_size=None, step_size=1, random_seed=None, **data_module_kwargs):\n",
    "        \"\"\" Predict.\n",
    "\n",
    "        Neural network prediction with PL's `Trainer` execution of `predict_step`.\n",
    "\n",
    "        **Parameters:**<br>\n",
    "        `dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).<br>\n",
    "        `test_size`: int=None, test size for temporal cross-validation.<br>\n",
    "        `step_size`: int=1, Step size between each window.<br>\n",
    "        `**data_module_kwargs`: PL's TimeSeriesDataModule args, see [documentation](https://pytorch-lightning.readthedocs.io/en/1.6.1/extensions/datamodules.html#using-a-datamodule).\n",
    "        \"\"\"\n",
    "        self._check_exog(dataset)\n",
    "        self._restart_seed(random_seed)\n",
    "        data_module_kwargs = self._set_quantile_for_iqloss(**data_module_kwargs)\n",
    "\n",
    "        self.predict_step_size = step_size\n",
    "        self.decompose_forecast = False\n",
    "        datamodule = TimeSeriesDataModule(dataset=dataset, \n",
    "                                          valid_batch_size=self.n_series,                                           \n",
    "                                          batch_size=self.n_series,\n",
    "                                          **data_module_kwargs)\n",
    "\n",
    "        # Protect when case of multiple gpu. PL does not support return preds with multiple gpu.\n",
    "        pred_trainer_kwargs = self.trainer_kwargs.copy()\n",
    "        if (pred_trainer_kwargs.get('accelerator', None) == \"gpu\") and (torch.cuda.device_count() > 1):\n",
    "            pred_trainer_kwargs['devices'] = [0]\n",
    "\n",
    "        trainer = pl.Trainer(**pred_trainer_kwargs)\n",
    "        fcsts = trainer.predict(self, datamodule=datamodule)\n",
    "        fcsts = torch.vstack(fcsts).numpy()\n",
    "\n",
    "        fcsts = np.transpose(fcsts, (2,0,1))\n",
    "        fcsts = fcsts.flatten()\n",
    "        fcsts = fcsts.reshape(-1, len(self.loss.output_names))\n",
    "        return fcsts\n",
    "\n",
    "    def decompose(self, dataset, step_size=1, random_seed=None, **data_module_kwargs):\n",
    "        raise NotImplementedError('decompose')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from fastcore.test import test_fail"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# test unsupported losses\n",
    "test_fail(\n",
    "    lambda: BaseMultivariate(\n",
    "        h=1,\n",
    "        input_size=1,\n",
    "        loss=losses.MQLoss(),\n",
    "        valid_loss=losses.RMSE(),\n",
    "        learning_rate=1,\n",
    "        max_steps=1,\n",
    "        val_check_steps=1,\n",
    "        n_series=1,\n",
    "        batch_size=1,\n",
    "    ),\n",
    "    contains='MQLoss() is not supported'\n",
    ")\n",
    "\n",
    "test_fail(\n",
    "    lambda: BaseMultivariate(\n",
    "        h=1,\n",
    "        input_size=1,\n",
    "        loss=losses.RMSE(),\n",
    "        valid_loss=losses.MASE(seasonality=1),\n",
    "        learning_rate=1,\n",
    "        max_steps=1,\n",
    "        val_check_steps=1,\n",
    "        n_series=1,\n",
    "        batch_size=1,\n",
    "    ),\n",
    "    contains='MASE() is not supported'\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
